# -*- coding: utf-8 -*-
"""
Created on Tue Oct 31 15:50:09 2017
Python version: 2.7
@author: teng-kuei
Modified by Jenn Asmussen on Tue Mar 1 2022
"""
import os
import sys
import pandas as pd
from glob import glob
from shutil import copyfile,rmtree
from re import search
import shlex
import subprocess
import numpy as np
import pandas as pd
from sqlite3 import connect
from scipy.stats import mstats,combine_pvalues
from scipy.stats.distributions import binom
from statsmodels.stats import multitest


def create_folder_if_not_exists(dir_to_create):
    """
    Create a folder (including missing parent folders) unless the path
    already exists.

    dir_to_create: path of the folder to create
    Raises OSError if the folder cannot be created and the path still does
    not exist afterwards.
    """
    try:
        os.makedirs(dir_to_create)
    except OSError:
        # Create first and check afterwards: a separate "exists" test before
        # makedirs can race with another process creating the same folder.
        # An already existing path is silently accepted, as before.
        if not os.path.exists(dir_to_create):
            raise


def is_number(s):
    """
    To determine if the input can be converted to a floating number,
    Return True or False
    """
    try:
        float(s)
    except (TypeError, ValueError):
        # ValueError: text that is not a number (e.g. "-" or "STOP")
        # TypeError: input that float() does not accept at all (e.g. None)
        return False
    return True
        
def get_SNV_by_NM(fn_EA_input): # shared by get_gene_profile, run_CI_NS, run_Freq
    """
    Read file in "ANNOVAR_EA" format,
    Return dictionary of mutation by NMID,
        keys = gene+"\t"+NMID,
        values = list of mutation type+"\t"+(amino acid)substitution+"\t"+EA

    Each tab-separated line carries the mutation type in column 2 (spaces are
    replaced by "_") and a comma-separated list of ":"-delimited SNV records
    in column 3; within a record the first two fields are gene and NMID and
    the last two are substitution and EA.
    Blank lines are skipped.
    """
    d_geneNM_mut = {}
    # The "with" block closes the file even if a malformed line raises
    with open(fn_EA_input) as f_in:
        for raw_line in f_in:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            fields = raw_line.split("\t")
            sub_type = fields[1].replace(" ", "_")
            for SNV in fields[2].split(","):
                mut_data = SNV.split(":")
                geneNM = mut_data[0]+"\t"+mut_data[1]
                sub,EA = mut_data[-2],mut_data[-1]
                d_geneNM_mut.setdefault(geneNM,[]).append("\t".join([sub_type,sub,EA]))
    return d_geneNM_mut

def get_gene_profile(dir_ANNOVAR, dir_Gprofile):
    """
    Step 1: Generate gene profile files
        (1) Read ANNOVAR_EA file and collect all mutations for each gene
        (2) Classify mutations for each gene
        (3) Generate "SNVs_count", showing the number of mutation in different mutation types for each gene
        (4) Generate "EAD_profiles", showing the distribution of EA score for missense mutation and the number of nonsense mutation for each gene

    dir_ANNOVAR: input folder holding "*.ANNOVAR_EA" files; it is concatenated
        directly with the file pattern, so it must end with a path separator
    dir_Gprofile: output folder; "SNVs_count/" and "EAD_profiles/" are created inside it
    """
    def get_mut_convert_table():
        """
        Return a dictionary for classifying ANNOVAR mutation types
        (read from "mut_type_convert.txt" in the current working directory)
        """
        convert_table = "mut_type_convert.txt"
        df = pd.read_csv(convert_table,sep="\t")
        return dict(zip(df['ANNOVAR_mutation_name'].tolist(),df['Mutation_type'].tolist()))

    def sep_by_subtype(d_geneNM_mut,d_ANNOVARmut_mut,l_mut_type):
        """
        For each gene, classify all their mutation
        Return dictionary:
            key = gene+"\t"+NM
            value = dictionary of all SNVs by mutation type
        """
        d_geneNM_muttype = {}   # key = gene NM; value = classified mutation type
        for geneNM, SNVs in d_geneNM_mut.items():
            # every known mutation type gets a (possibly empty) list for each gene
            d_SNV_by_type = {mt: [] for mt in l_mut_type}
            for SNV in SNVs:
                sub_type = SNV.split("\t")[0]
                if sub_type in d_ANNOVARmut_mut:
                    d_SNV_by_type[d_ANNOVARmut_mut[sub_type]].append(SNV)
                else:
                    print("\t\tWarning: find other mutation type: "+sub_type)
            d_geneNM_muttype[geneNM] = d_SNV_by_type
        return d_geneNM_muttype

    def get_NM_mut_count(dir_EA_SNV_out,cancer_type,d_geneNM_muttype,l_mut_type):
        """
        Write a file showing the number of mutation in different mutation types for each gene
        """
        with open(dir_EA_SNV_out+cancer_type+"_SNVs_count.tsv","w") as f:
            f.write("Gene\tNM\t"+"\t".join(l_mut_type)+"\ttotal\n")
            for geneNM in d_geneNM_muttype:
                d_mut = d_geneNM_muttype[geneNM]
                counts = [len(d_mut[mt]) for mt in l_mut_type]
                # geneNM already contains gene+"\t"+NM, so it fills the first two columns
                f.write("{}\t{}\t{}\n".format(geneNM,"\t".join(map(str,counts)),sum(counts)))

    def get_NM_EAprofile(dir_EA_bins_out,cancer_type,d_geneNM_muttype):
        """
        Write a file showing the distribution of EA score for missense mutation and the number of nonsense mutation for each gene
        """
        def get_EAs(missenses):
            """
            Return (list of numeric EA scores as float, list of records whose EA is not numeric)
            """
            EAs = []
            nonEA = []
            for missense in missenses:
                EA=missense.split("\t")[-1]
                if is_number(EA):
                    EAs.append(float(EA))
                else:
                    nonEA.append(missense)
            return EAs,nonEA

        with open(dir_EA_bins_out+cancer_type+"_EAD_profiles.tsv","w") as f:
            f.write("Gene\tNM\t[0,10)\t[10,20)\t[20,30)\t[30,40)\t[40,50)\t[50,60)\t[60,70)\t[70,80)\t[80,90)\t[90,100]\tnonsense\n")
            for geneNM in d_geneNM_muttype:
                missenses = d_geneNM_muttype[geneNM]['missense']
                nonsenses = d_geneNM_muttype[geneNM]['nonsense']
                EAs,_ = get_EAs(missenses)
                # ten bins with edges 0,10,...,100
                binned_EA = np.histogram(EAs, range(0,101,10))[0]
                f.write("{}\t{}\t{}\n".format(geneNM,"\t".join(map(str,binned_EA)),len(nonsenses)))

    print("Step 1 : Generate gene profile files:")
    print("\t- Output folder: "+dir_Gprofile)

    # Create output folder
    create_folder_if_not_exists(dir_Gprofile)

    # Get a list of ANNOVAR_EA file
    fn_EA_inputs = glob(dir_ANNOVAR+"*.ANNOVAR_EA")
    print("Input ANNOVAR files: {}".format(fn_EA_inputs))
    print("dirANNOVAR: {}".format(dir_ANNOVAR))

    # Nothing to do without input files (no sub-folders are created, as before)
    if not fn_EA_inputs:
        return

    # Output folders and the conversion table do not depend on the input
    # file, so they are prepared once instead of once per file.
    dir_EA_SNV_out = dir_Gprofile+"/SNVs_count/"
    create_folder_if_not_exists(dir_EA_SNV_out)

    dir_EA_bins_out = dir_Gprofile+"/EAD_profiles/"
    create_folder_if_not_exists(dir_EA_bins_out)

    d_ANNOVARmut_mut = get_mut_convert_table()
    # column order of "SNVs_count" follows the iteration order of this set
    l_mut_type = list(set(d_ANNOVARmut_mut.values()))

    for fn_EA_input in fn_EA_inputs:
        cancer_type=os.path.basename(fn_EA_input).split(".")[0]
        print("\t\t-"+cancer_type)

        ### get SNVs by NM ###
        d_geneNM_mut = get_SNV_by_NM(fn_EA_input) # get dictionary of mutation by NMID

        ### separate SNVs by types ###
        d_geneNM_muttype = sep_by_subtype(d_geneNM_mut,d_ANNOVARmut_mut,l_mut_type)

        ### FILE 1: get gene mutation type count ###
        get_NM_mut_count(dir_EA_SNV_out,cancer_type,d_geneNM_muttype,l_mut_type)

        ### FILE 2: get gene mutation profile ###
        get_NM_EAprofile(dir_EA_bins_out,cancer_type,d_geneNM_muttype)


def run_CI_anaylsis(dir_ANNOVAR,dir_CIresult):
    """
    Step 2 : Generate CI files
        (1) Calculate sLOF/sGOF indexes with perl code "Isc_for_TCGA_variants"
        (2) Run CI analysis and get p-values for CI (missense+nonsense) and frequency
        (3) Correct for multi-testing and get q-value

    dir_ANNOVAR: input folder holding "*.ANNOVAR_EA" files
    dir_CIresult: output folder for the "[cancer]_CI.tsv" files
    Both paths are concatenated directly with file names, so they must end
    with a path separator.
    """

    def run_Isc(dir_ANNOVAR,dir_sLOFsGOF):
        """
        Calculate sLOF/sGOF indexes with perl code "Isc_for_TCGA_variants"
        """
        perl_code = "/Isc_for_TCGA_variants.pl"
        # Hand the argument list to Popen directly: joining it into one string
        # and re-splitting it breaks paths that contain spaces.
        cmd = ["perl",os.getcwd()+perl_code,dir_ANNOVAR,dir_sLOFsGOF,os.getcwd()]
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        # communicate() drains the stdout pipe while waiting; wait() alone can
        # block forever once the child has filled the pipe buffer.
        proc.communicate()
        if proc.returncode != 0:
            print("\t\tWarning: Isc_for_TCGA_variants.pl exited with status {}".format(proc.returncode))

    def get_gene_syn():
        """
        Return dictionary of gene synonym
        (first column -> second column of "gene_synonym.txt")
        """
        fn_syn = "gene_synonym.txt"
        df = pd.read_csv(fn_syn,sep="\t",header=None)
        gene_syn = dict(zip(df[0],df[1]))
        return gene_syn

    def get_gene_NM_size():
        """
        Return two dictionaries:
            dictionary "gene_NM": key = gene, value = NM
            dictionary "gene_size": key = gene, value = Size (protein length)
        """
        fn = "gene_NM_size.tsv"
        df = pd.read_csv(fn,sep="\t")
        gene_NM = dict(zip(df['Gene'].tolist(),df['NM'].tolist()))
        gene_size = dict(zip(df['Gene'].tolist(),df['Size'].tolist()))
        return gene_NM,gene_size

    def run_CI_get_pvalues(fn_ANNOVARs,dir_CIresult,gene_syn,gene_size):
        """
        Run CI analysis and get p-values for CI-ns and frequency
        Write output files as "[cancer].temp"
        Return the list of "[cancer].temp" files

        Note: "gene_NM" is read from the enclosing function scope.
        """

        def get_gene_info(fn_ANNOVAR):
            """
            Return dictionaries of EA scores, missense, nonsense, silent, other mutations for each gene
            Blank lines are skipped.
            """
            gene_EA,gene_M,gene_S,gene_N,gene_O = {},{},{},{},{}
            # The "with" block closes the file even if a malformed line raises
            with open(fn_ANNOVAR) as f_in:
                for raw_line in f_in:
                    raw_line = raw_line.strip()
                    if not raw_line:
                        continue
                    ln = raw_line.split("\t")
                    # gene = first ":" field and EA = last ":" field of column 3
                    snv_fields = ln[2].split(":")
                    gene,mut_type,EA = snv_fields[0],ln[1],snv_fields[-1]
                    if gene not in gene_EA:
                        gene_EA[gene],gene_M[gene],gene_S[gene],gene_N[gene],gene_O[gene]=[],0,0,0,0
                    if mut_type == "nonsynonymous SNV":
                        gene_M[gene]+=1
                        if is_number(EA):
                            gene_EA[gene].append(float(EA))
                    elif mut_type == "synonymous SNV":
                        gene_S[gene]+=1
                    elif mut_type == "stopgain SNV":
                        gene_N[gene]+=1
                    else:
                        gene_O[gene]+=1
            return gene_EA,gene_M,gene_S,gene_N,gene_O

        def get_mut_rate(gene_size,gene_O):
            """
            Return the average mutation rate of other mutations
            (total count of other mutations / total size, over genes with a known size)
            """
            tot_size,tot_O = 0,0
            for gene in gene_O:
                if gene in gene_size:
                    tot_size+=gene_size[gene]
                    tot_O+=gene_O[gene]
            if tot_size == 0:
                # No gene has a known size: the rate is never used in that case
                # (the frequency p-value falls back to 1), so avoid dividing by zero.
                return 0.0
            return float(tot_O)/tot_size

        def get_random_for_NM(NM):
            """
            Return the list of values stored for the input NM in the fourth
            column of table "NM_gene_randEA_randEANS" (empty list if the NM is
            not found); presumably the EA scores of random missense and
            nonsense mutations - verify against the database schema
            """
            rand_EA_NS = []
            db = connect('NM_gene_rand_ea.sql')
            try:
                cur = db.cursor()
                cur.execute('''SELECT * FROM NM_gene_randEA_randEANS WHERE NM = ?''',(NM,))
                row = cur.fetchone()
                if row:
                    rand_EA_NS = [float(x) for x in row[3].split(',')]
            finally:
                # this function runs once per gene; close the connection so
                # handles do not pile up
                db.close()
            return rand_EA_NS

        fn_EAD_FQs = []
        for fn_ANNOVAR in fn_ANNOVARs:
            # get cancer type
            cancer = os.path.basename(fn_ANNOVAR).split(".")[0]
            print("\t\t-"+cancer)

            # create temp output file for p-values
            fn_EAD_FQ = dir_CIresult+cancer+".temp"
            fn_EAD_FQs.append(fn_EAD_FQ)
            gene_EA,gene_M,gene_S,gene_N,gene_O  = get_gene_info(fn_ANNOVAR)    # collect gene info as dictionaries
            rate_O = get_mut_rate(gene_size,gene_O)   # calculate mutation rate of other mutation types

            with open(fn_EAD_FQ,"w") as f:
                f.write("Gene\tSize\t#Missense\t#Nonsense\t#Others\tCI-ns,p\tFQ,p\tCI,p\t[0,10)\t[10,20)\t[20,30)\t[30,40)\t[40,50)\t[50,60)\t[60,70)\t[70,80)\t[80,90)\t[90,100]\n")
                for gene in gene_EA:

                    # Return gene synonym (the gene itself when no synonym is listed)
                    g_syn = gene_syn.get(gene,gene)

                    # Collect EA scores
                    M_EA = gene_EA[gene]
                    binned_EA = [str(count) for count in np.histogram(M_EA, range(0,101,10))[0]] # Bin EA

                    # Get CI-ns p-value
                    MN_EA = M_EA+[100.]*gene_N[gene]    # Add nonsense mutation as EA of 100
                    if g_syn in gene_NM:
                        MN_EA_rand = get_random_for_NM(gene_NM[g_syn])  # Get all random mutations for the given NM
                    else:
                        MN_EA_rand = []
                    if len(MN_EA)>0 and len(MN_EA_rand)>0:
                        ks_p = mstats.ks_twosamp(MN_EA,MN_EA_rand,alternative='less')[1]    # Run KS-test to compare the EA scores of observed mutations and random mutations
                    else:
                        ks_p = 1

                    # Get frequency p-value
                    # NOTE(review): size is looked up with the original gene symbol
                    # while the NM lookup above uses the synonym - confirm intended.
                    if gene in gene_size:
                        size = gene_size[gene]
                        # NOTE(review): 1-cdf(k) is P(X > k), i.e. it excludes the
                        # observed count itself - confirm this is the intended tail.
                        fq_p_O = 1-binom.cdf(gene_O[gene],gene_size[gene],rate_O)   # Calculate frequency p-value of other mutations with binomial test
                    else:
                        size = "-"
                        fq_p_O =1

                    # combine p-value
                    cb_p_O = combine_pvalues([ks_p,fq_p_O])[1]  # Combine CI-ns and frequency p-value
                    f.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(g_syn,size,gene_M[gene],gene_N[gene],gene_O[gene],ks_p,fq_p_O,cb_p_O,"\t".join(binned_EA)))
        return fn_EAD_FQs

    def get_q_and_merge_files(dir_sLOFsGOF,fn_EAD_FQs):
        """
        Run multi-test correction and acquire q-values
        Merge results from CI and sLOF/sGOF
        Write output files "[Cancer]_CI.tsv"
        Remove the "[cancer].temp" files and the sLOF/sGOF folder afterwards
        """
        for fn_p in fn_EAD_FQs:
            cancer = os.path.basename(fn_p).split(".")[0]
            df = pd.read_csv(fn_p,sep="\t")
            # drop genes without a known size; copy() so the q-value columns
            # are added to an independent frame rather than a slice
            df = df[~df['Size'].isin(['-'])].copy()
            p_columns = [c for c in df.columns.tolist() if c.endswith(",p")]
            for p_column in p_columns:
                q_column = p_column[:-1]+"q"
                df[q_column] = multitest.multipletests(df[p_column].tolist(), method = 'fdr_bh')[1] # Convert p-values to q-values with FDR

            # raises IndexError when the Perl step produced no "[cancer].Isc" file
            fn_sLOFsGOF = glob(dir_sLOFsGOF+cancer+".Isc")[0]
            df_sLOFGOF = pd.read_csv(fn_sLOFsGOF,sep="\t")
            df_join = pd.merge(df,df_sLOFGOF,on=['Gene'])

            columns = ["Gene","Size","#Missense","#Nonsense","#Others","CI-ns,p","CI-ns,q","FQ,p","FQ,q","CI,p","CI,q",'sLOF','sGOF',"[0,10)","[10,20)","[20,30)","[30,40)","[40,50)","[50,60)","[60,70)","[70,80)","[80,90)","[90,100]"]
            # Strip only the ".temp" extension: cutting at the first "." of the
            # whole path misplaces the output when a folder name contains a dot
            # (e.g. "./out/").
            df_join[columns].to_csv(os.path.splitext(fn_p)[0]+"_CI.tsv",sep="\t",index=False)

        # remove temp files
        for fn_p in fn_EAD_FQs:
            os.remove(fn_p)
        rmtree(dir_sLOFsGOF)

    # Create output folder
    create_folder_if_not_exists(dir_CIresult)

    # Get a list of ANNOVAR_EA file
    fn_ANNOVARs = glob(dir_ANNOVAR+"*.ANNOVAR_EA")

    # Calculate sLOF/sGOF indexes
    print("Step 2.1 : Generate CI files - sLOF/sGOF")
    dir_sLOFsGOF = dir_CIresult+"sLOFsGOF/"
    create_folder_if_not_exists(dir_sLOFsGOF)
    print("\t- Output folder: "+dir_sLOFsGOF)
    # Run Perl code to calculate sLOF/sGOF indexes
    run_Isc(dir_ANNOVAR,dir_sLOFsGOF)

    # Get CI(NS) p-value
    print("Step 2.2 : Generate CI files - Run CI analysis")
    print("\t- Output folder: "+ dir_CIresult)

    # Run CI analysis to calculate p-values for CI(missense and nonsense), frequency

    # Collect gene info from dependency files
    gene_NM,gene_size = get_gene_NM_size()
    gene_syn = get_gene_syn()
    fn_EAD_FQs = run_CI_get_pvalues(fn_ANNOVARs,dir_CIresult,gene_syn,gene_size)

    # Get q-values and merge results from CI and sLOF/sGOF indexes
    get_q_and_merge_files(dir_sLOFsGOF,fn_EAD_FQs)
    
def get_SMG(dir_CIresult,dir_SMG):
    """
    Step 3: Identify significantly mutated genes
        (1) Selection gene with CI-ns,p<0.05
        (2) Selection gene with CI,q<0.1
        (3) Selection gene with sLOF>=0.15 or sGOF>=0.1

    dir_CIresult: folder holding the "[cancer]_CI.tsv" files
    dir_SMG: output folder; one "[cancer].tsv" is written per input file
    Both paths are concatenated directly with file names, so they must end
    with a path separator.
    """
    print("Step 3 : Get significantly mutated genes")
    print("\t- Output folder: "+dir_SMG)
    # Create output folder
    create_folder_if_not_exists(dir_SMG)

    suffix = "_CI.tsv"
    for fn_CI in glob(dir_CIresult+"*"+suffix):
        # Strip only the known suffix: cutting at the first "_" truncates
        # cohort names that contain underscores and makes such cohorts
        # overwrite each other's output.
        cancer = os.path.basename(fn_CI)[:-len(suffix)]
        print("\t\t-"+cancer)
        df_CI = pd.read_csv(fn_CI,sep="\t")
        significant = (
            (df_CI['CI-ns,p']<0.05)     # Selection gene with CI-ns,p<0.05
            & (df_CI['CI,q']<0.1)       # Selection gene with CI,q<0.1
            & ~((df_CI['sLOF']<0.15)&(df_CI['sGOF']<0.1))   # Selection gene with sLOF>=0.15 or sGOF>=0.1
        )
        df_CI[significant].to_csv(dir_SMG+cancer+".tsv",sep="\t",index=False)



## Run ##
if __name__ == "__main__":

    # Get input folder and output folder from command line
    print("Usage: python [CohortInteg code name] [annovar_ea folder] [output folder]")

    if len(sys.argv) < 3:
        # Stop with a clear message instead of an IndexError below
        sys.exit("Error: both [annovar_ea folder] and [output folder] are required")

    dir_ANNOVAR,dir_output = sys.argv[1],sys.argv[2]

    # All steps build paths by plain concatenation, so both folders need a
    # trailing slash; without it the "*.ANNOVAR_EA" pattern matches nothing.
    if not dir_ANNOVAR.endswith("/"):
        dir_ANNOVAR = dir_ANNOVAR+"/"
    if not dir_output.endswith("/"):
        dir_output = dir_output+"/"

    # prepare output folders
    create_folder_if_not_exists(dir_output)
    dir_Gprofile = dir_output+"Gene_profiles/"
    dir_CIresult = dir_output+"CI_results/"
    dir_SMG = dir_output+"Significant_genes/"

    # Step 1 : Generate gene profile files
    get_gene_profile(dir_ANNOVAR,dir_Gprofile)

    # Step 2 : Generate CI files
    run_CI_anaylsis(dir_ANNOVAR,dir_CIresult)

    # Step 3 : Get significant genes
    get_SMG(dir_CIresult,dir_SMG)
    

